//InstDecode.scala
package mycore

import chisel3._
import chisel3.util._

/** Instruction-decode (ID) pipeline stage.
  *
  * Latches the incoming IF/ID bundle while `ENABLE` is high. It then:
  *  - decodes the latched instruction with [[Decoder]];
  *  - reads both source operands from the register file, bypassing the
  *    write-back port when it targets the same register in the same cycle;
  *  - drives the ID/EX bundle;
  *  - exposes the register-file contents on `gpr` for difftest.
  */
class InstDecode extends Module {
  val io = IO(new Bundle {
    val ENABLE = Input(Bool())                    // stage enable: gates the IF/ID latch and the register-file write
    val ifid   = Flipped(new IFID)                // signals arriving from the fetch stage
    val idex   = new IDEX                         // signals forwarded to the execute stage
    val cancel = Input(Bool())                    // when high, the instruction handed to EX is marked invalid
    val wbdata = Flipped(new WBDATA)              // write-back port (wen / regdes / data)
    val gpr    = Output(Vec(32, UInt(64.W)))      // register-file snapshot for difftest
  })

  // ---- Pipeline register: capture IF/ID only while the stage is enabled (reset to all zeros) ----
  val stage = RegEnable(io.ifid, 0.U.asTypeOf(new IFID), io.ENABLE)
  val inst  = stage.Inst

  // ---- Register-specifier fields taken from fixed bit ranges of the instruction ----
  val rs1    = inst(19, 15)
  val rs2    = inst(24, 20)
  val regdes = inst(11, 7)

  // ---- Decode ----
  val decoder = Module(new Decoder)
  decoder.io.inst := inst

  // ---- Operand fetch with write-back bypass ----
  val regfile = new RegFile

  /** True when the write-back port writes register `addr` this cycle. Register 0 never matches. */
  private def writebackTargets(addr: UInt): Bool =
    io.wbdata.wen && addr === io.wbdata.regdes && addr =/= 0.U

  /** Reads register `addr`, preferring the in-flight write-back value when it targets the same register.
    * NOTE(review): presumably `RegFile.read` only reflects a write after the clock edge; hence the bypass.
    */
  private def fetchOperand(addr: UInt) =
    Mux(writebackTargets(addr), io.wbdata.data, regfile.read(addr))

  val src1 = fetchOperand(rs1)
  val src2 = fetchOperand(rs2)

  // NOTE(review): the write-back commit is gated by this stage's ENABLE. If the
  // write-back stage can advance while ID is stalled, that write would be lost.
  // Confirm that ENABLE stalls the whole pipeline in lockstep.
  when(io.ENABLE) {
    regfile.write(io.wbdata.wen, io.wbdata.regdes, io.wbdata.data)
  }

  // ---- Outputs to the execute stage ----
  io.idex.Valid  := Mux(!io.cancel, stage.Valid, false.B)   // a cancelled slot becomes a bubble
  io.idex.Inst   := inst
  io.idex.PC     := stage.PC
  io.idex.isa    := decoder.io.isa
  io.idex.src1   := src1
  io.idex.src2   := src2
  io.idex.imm    := decoder.io.imm
  io.idex.wen    := decoder.io.wen
  io.idex.regdes := regdes

  // ---- Difftest: expose the architectural register file ----
  io.gpr := regfile.gpr
}